{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f71bf07d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-24T02:01:55.525012Z",
     "start_time": "2024-01-24T02:01:55.515859Z"
    }
   },
   "outputs": [],
   "source": [
    "import akshare as ak\n",
    "import pandas as pd\n",
    "from __future__ import print_function\n",
    "import datetime \n",
    "import numpy as np\n",
    "import sklearn\n",
    "import pandas_datareader.data as web\n",
    "from sklearn.ensemble import RandomForestClassifier \n",
    "from sklearn.linear_model import LogisticRegression \n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA\n",
    "from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis as QDA\n",
    "from sklearn.metrics import confusion_matrix \n",
    "from sklearn.svm import LinearSVC, SVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "a279a288",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-24T02:58:04.163194Z",
     "start_time": "2024-01-24T02:58:03.821144Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                    日期       开盘       收盘       最高       最低        成交量  \\\n",
      "日期                                                                      \n",
      "2022-01-04  2022-01-04  3649.15  3632.33  3651.89  3610.09  405027768   \n",
      "2022-01-05  2022-01-05  3628.26  3595.18  3628.26  3583.47  423902029   \n",
      "2022-01-06  2022-01-06  3581.22  3586.08  3594.49  3559.88  371540544   \n",
      "\n",
      "                     成交额    振幅   涨跌幅    涨跌额   换手率  \n",
      "日期                                                 \n",
      "2022-01-04  5.102511e+11  1.15 -0.20  -7.45  0.90  \n",
      "2022-01-05  5.389636e+11  1.23 -1.02 -37.15  0.94  \n",
      "2022-01-06  4.742843e+11  0.96 -0.25  -9.10  0.82  \n"
     ]
    }
   ],
   "source": [
    "## 1. 获取行情的例子\n",
    "index_zh_a_hist_df = ak.index_zh_a_hist(symbol=\"000001\", period=\"daily\", start_date=\"20220101\", end_date='20231231')\n",
    "index_zh_a_hist_df.set_index(pd.to_datetime(index_zh_a_hist_df[\"日期\"],format='%Y-%m-%d'),inplace=True)\n",
    "print(index_zh_a_hist_df[0:3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "id": "1f527391",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-24T06:18:56.008004Z",
     "start_time": "2024-01-24T06:18:56.002502Z"
    }
   },
   "outputs": [],
   "source": [
    "def create_lagged_series(sym, startdate, enddate, lags=5): \n",
    "    # Obtain stock information \n",
    "    ts = ak.index_zh_a_hist(symbol=sym, period=\"daily\", start_date=startdate, end_date=enddate)\n",
    "    ts.set_index(pd.to_datetime(ts[\"日期\"],format='%Y-%m-%d'),inplace=True)\n",
    "    # Create the new lagged DataFrame\n",
    "    tslag = pd.DataFrame(index=ts.index)\n",
    "    tslag[\"Today\"] = ts[\"收盘\"]\n",
    "    tslag[\"Volume\"] = ts[\"成交量\"]\n",
    "    # Create the shifted lag series of prior trading period close values\n",
    "    for i in range(0, lags):\n",
    "        tslag[\"Lag%s\" % str(i+1)] = ts[\"收盘\"].shift(i+1)\n",
    "    # Create the returns DataFrame\n",
    "    tsret = pd.DataFrame(index=tslag.index) \n",
    "    tsret[\"Volume\"] = tslag[\"Volume\"]\n",
    "    tsret[\"Today\"] = tslag[\"Today\"].pct_change()*100.0\n",
    "    # If any of the values of percentage returns equal zero, set them to # a small number (stops issues with QDA model in Scikit-Learn)\n",
    "    # enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列，同时列出数据和数据\n",
    "    for i,x in enumerate(tsret[\"Today\"]):\n",
    "        if (abs(x) < 0.0001): \n",
    "            tsret[\"Today\"][i] = 0.0001\n",
    "    # Create the lagged percentage returns columns\n",
    "    for i in range(0, lags):\n",
    "        tsret[\"Lag%s\" % str(i+1)] = tslag[\"Lag%s\" % str(i+1)].pct_change()*100.0\n",
    "    # Create the \"Direction\" column (+1 or -1) indicating an up/down day\n",
    "    # sign()是Python的Numpy中的取数字符号（数字前的正负号）的函数\n",
    "    tsret[\"Direction\"] = np.sign(tsret[\"Today\"]) \n",
    "    startdatepd = pd.to_datetime(startdate,format='%Y%m%d')\n",
    "    tsret = tsret[tsret.index >= startdatepd]\n",
    "    \n",
    "    return tsret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "bb22a4e7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-24T08:21:06.498232Z",
     "start_time": "2024-01-24T08:21:03.642831Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------->\n",
      "               Volume     Today      Lag1      Lag2  Direction\n",
      "日期                                                            \n",
      "2020-01-02  292470208       NaN       NaN       NaN        NaN\n",
      "2020-01-03  261496668 -0.045702       NaN       NaN       -1.0\n",
      "2020-01-06  312575841 -0.012322 -0.045702       NaN       -1.0\n",
      "2020-01-07  276583112  0.693712 -0.012322 -0.045702        1.0\n",
      "2020-01-08  297872553 -1.221013  0.693712 -0.012322       -1.0\n",
      "...               ...       ...       ...       ...        ...\n",
      "2023-12-25  229814178  0.138261 -0.134649  0.571998        1.0\n",
      "2023-12-26  228140855 -0.682813  0.138261 -0.134649       -1.0\n",
      "2023-12-27  247900882  0.542623 -0.682813  0.138261        1.0\n",
      "2023-12-28  339213116  1.375484  0.542623 -0.682813        1.0\n",
      "2023-12-29  290672687  0.684672  1.375484  0.542623        1.0\n",
      "\n",
      "[970 rows x 5 columns]\n",
      "Hit Rates/Confusion Matrices:\n",
      "\n",
      "LR:\n",
      "0.483\n",
      "[[ 13  14]\n",
      " [111 104]]\n",
      "\n",
      "LDA:\n",
      "0.483\n",
      "[[ 13  14]\n",
      " [111 104]]\n",
      "\n",
      "QDA:\n",
      "0.459\n",
      "[[101 108]\n",
      " [ 23  10]]\n",
      "\n",
      "LSVC:\n",
      "0.483\n",
      "[[ 13  14]\n",
      " [111 104]]\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/neil/anaconda3/lib/python3.10/site-packages/sklearn/svm/_classes.py:31: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n",
      "/Users/neil/anaconda3/lib/python3.10/site-packages/sklearn/svm/_base.py:1237: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RSVM:\n",
      "0.455\n",
      "[[90 98]\n",
      " [34 20]]\n",
      "\n",
      "RF:\n",
      "0.463\n",
      "[[63 69]\n",
      " [61 49]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Create a lagged series\n",
    "lags_val = 2\n",
    "snpret = create_lagged_series(\"000001\", \"20200101\", \"20231231\", lags=lags_val)\n",
    "print(\"----------->\")\n",
    "print(snpret)\n",
    "# 过去2天的涨跌幅和当日的涨跌结果作为训练集合的输入和输出\n",
    "X = snpret[[\"Lag1\",\"Lag2\"]]\n",
    "y = snpret[\"Direction\"]\n",
    "# 前3年的数据作为训练集，对最近一年进行预测校验\n",
    "start_test = datetime.datetime(2023,1,1)\n",
    "X_train = X[X.index < start_test][lag_val + 1:]\n",
    "X_test = X[X.index >= start_test]\n",
    "y_train = y[y.index < start_test][lag_val + 1:]\n",
    "y_test = y[y.index >= start_test]\n",
    "\n",
    "print(\"Hit Rates/Confusion Matrices:\\n\") \n",
    "# 训练模型，作为一个列表，包含常用的模型\n",
    "models = [(\"LR\", LogisticRegression()),\n",
    "    (\"LDA\", LDA()),(\"QDA\", QDA()),(\"LSVC\", LinearSVC()),\n",
    "    (\"RSVM\", SVC(\n",
    "    C=1000000.0, cache_size=200, class_weight=None, coef0=0.0, degree=3, gamma=0.0001, kernel='rbf',\n",
    "    max_iter=-1, probability=False, random_state=None,\n",
    "      shrinking=True, tol=0.001, verbose=False)\n",
    "    ),\n",
    "    (\"RF\", RandomForestClassifier(\n",
    "      n_estimators=1000, criterion='gini',\n",
    "      max_depth=None, min_samples_split=2,\n",
    "    min_samples_leaf=1, \n",
    "                bootstrap=True, oob_score=False, n_jobs=1,\n",
    "                random_state=None, verbose=0)\n",
    "    )]\n",
    "# Iterate through the models\n",
    "for m in models:\n",
    "    # Train each of the models on the training set\n",
    "    m[1].fit(X_train, y_train)\n",
    " \n",
    "    # Make an array of predictions on the test set\n",
    "    pred = m[1].predict(X_test)\n",
    "    # Output the hit-rate and the confusion matrix for each model\n",
    "    print(\"%s:\\n%0.3f\" % (m[0], m[1].score(X_test, y_test))) \n",
    "    print(\"%s\\n\" % confusion_matrix(pred, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c272214",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
